library(yaml)
library(dplyr)
library(ggplot2)
library(stringr)
library(miceadds)
library(pspearman)
library(lspline)
library(car)
library(dfadjust)
library(corrplot)

### Load Paths
# Input data and the connected-plot PDFs share the descriptive build directory
# named in the global config, so both path variables point to the same place.
CONFIG     <- yaml.load_file('config_global.yaml')
build_path <- CONFIG$build$descriptive
data_path  <- CONFIG$build$descriptive
# Project helper file; presumably provides write_gslab_table(), which is used
# below but not defined in this script -- confirm.
source(sprintf("%s/library.r", CONFIG$source$lib))

### Load Data
# List of per-driver data frames, indexed by driver name.
drivers_data <- readRDS(sprintf("%s/drivers.rds", data_path))

### Main
# Entry point: draws the per-country connected plots and the cross-country
# unity plots for every driver, then writes the rank-correlation table of the
# main drivers' country-level slopes.
main <- function() {

  # Setup: classify countries by polarization trend and build the lookup table
  # of display names / file-name stems
  polarization  <- drivers_data[["Affective polarization"]]
  country_split <- split_countries_polar(polarization)
  country_map   <- make_country_map(country_split)
  all_drivers   <- names(drivers_data)

  # (country, driver) pairs that are excluded, one pair per row
  drop <- matrix(
    c("France",      "Ethnic fractionalization",
      "France",      "Ethnic polarization",
      "France",      "Non-white share",
      "Japan",       "Elite polarization",
      "Switzerland", "Elite polarization",
      "New Zealand", "Priv. 24-hr TV news (share)",
      "New Zealand", "Priv. 24-hr TV news (count)",
      "New Zealand", "Share getting news online",
      "New Zealand", "Share getting news on social media"),
    ncol     = 2,
    byrow    = TRUE,
    dimnames = list(NULL, c("country_drop", "driver_drop"))
  )

  # Make Plots (both plotting functions are called only for their side effects)
  for (driver in all_drivers) {
    make_connected_plot(driver, drivers_data, country_map, drop,
                        year_range = c(1965, 2020))
  }
  for (driver in all_drivers[all_drivers != "Affective polarization"]) {
    make_unity_plot(driver, drivers_data, country_split, drop)
  }

  # Correlation matrix: columns 2:3 of each slopes file hold the country and
  # that driver's slope; the files are merged on their shared column
  main_drivers <- c("inequality", "trade", "onlineshare", "tvcount",
                    "shareforeign", "sharenotwhite", "partisansorting",
                    "elitepolarization")
  read_slopes  <- function(d) {
    read_csv(sprintf("build/descriptive/slopes_%s.csv", d))[, 2:3]
  }
  slope_data   <- Reduce(function(merged, nxt) merge(merged, nxt, all = TRUE),
                         lapply(main_drivers, read_slopes))

  cors <- cor(slope_data[, -1], use = "pairwise.complete.obs", method = "spearman")
  cors <- cbind(rownames(cors), cors)
  write_gslab_table(cors, "build/descriptive/cor_table.txt", "<tab:cor_table>")
}

### Main Functions
# Classify countries by the sign of their linear trend in `partisanaffect`.
#
# polar: data frame with columns `country`, `years` and `partisanaffect`.
#
# Returns a list with
#   rising      - countries whose full-period slope is > 0
#   falling     - all other countries
#   slopes      - full-period slope per country, in order of `countries`
#   countries   - unique(polar$country)
#   post_slopes - slope of the second segment of a linear spline with a knot
#                 at 2000, per country, in order of `countries`
split_countries_polar <- function(polar) {

  all_countries <- unique(polar$country)

  rising      <- c()
  falling     <- c()
  slopes      <- c()
  post_slopes <- c()

  for (cty in all_countries) {

    country_data <- polar %>% filter(country == !!cty)

    # Full-period trend
    linear_fit <- lm(partisanaffect ~ years, country_data)
    trend      <- linear_fit$coefficients["years"]

    # Trend after the 2000 knot
    spline_fit <- lm(partisanaffect ~ lspline(years, c(2000)), data = country_data)
    post_trend <- spline_fit$coefficients["lspline(years, c(2000))2"]

    if (trend > 0) {
      rising  <- c(rising, cty)
    } else {
      falling <- c(falling, cty)
    }

    slopes      <- c(slopes, trend)
    post_slopes <- c(post_slopes, post_trend)
  }

  return(list(rising      = rising,
              falling     = falling,
              slopes      = slopes,
              countries   = all_countries,
              post_slopes = post_slopes))
}


# Build the country lookup table used for plot titles and output file names.
#
# country_split: list whose first two elements hold the rising and the falling
#                countries (as returned by split_countries_polar()).
#
# Returns a data frame with one row per country and the columns
#   country   - the country as it appears in the data
#   name      - display name ("US" / "Britain" for the United States / Kingdom)
#   extension - file-name stem: words de-capitalised by str_decap() and joined
#               with "_" ("united_kingdom" is replaced by "britain")
make_country_map <- function(country_split) {

  # Rising countries first, then falling ones
  country <- unlist(country_split[c(1, 2)], recursive = FALSE, use.names = FALSE)

  # Display names
  name                           <- country
  name[name == "United States"]  <- "US"
  name[name == "United Kingdom"] <- "Britain"

  # File-name stems
  to_extension <- function(full_name) {
    words <- strsplit(full_name, " ", fixed = TRUE)[1]
    paste0(sapply(words, str_decap), collapse = "_")
  }
  extension                                <- sapply(country, to_extension, USE.NAMES = FALSE)
  extension[extension == "united_kingdom"] <- "britain"

  country_map <- as.data.frame(list(country = country, name = name, extension = extension))

  return(country_map)
}


# Draw one connected plot per country for a single driver.
#
# driver:       name of the element of `drivers_data` to plot.
# drivers_data: named list of data frames; the first two columns are expected
#               to be the country and `years`, the third column is the driver's
#               value (it is renamed to `value` here).
# country_map:  data frame from make_country_map(); one panel per row.
# drop:         two-column matrix, column 1 = country, column 2 = driver. A
#               country listed for `driver` gets an asterisk in its title.
# year_range:   first and last year used to size the y axis and shown on the
#               x axis.
#
# Every panel of a driver gets the same axis height (the largest floor-to-
# ceiling range of any country); only the starting point differs by country.
# Returns NULL; the plots are written to disk by make_connected_plot_i().
make_connected_plot <- function(driver, drivers_data, country_map, drop, year_range = c(1965, 2020)) {

  data_all           <- drivers_data[[driver]] %>% mutate(years = as.numeric(as.character(years))) %>%
                        filter(years %in% c(year_range[1]:year_range[2]))
  names(data_all)[3] <- "value"
  data_all           <- data_all %>%
                        group_by(country) %>%
                        dplyr::mutate(y_min = floor(min(value)), y_max = ceiling(max(value))) %>%
                        ungroup() %>%
                        mutate(diff_max = round((max(y_max - y_min))))

  # Axis limits and ticks for a panel that starts at `base_min` and is `span`
  # high (plus 5% head-room below).
  compute_axis <- function(base_min, span) {
    y_min <- base_min - span / 20
    y_max <- base_min + span
    step  <- round((y_max - y_min) / 5)
    # A span of 2 or less rounds to a step of 0, which makes seq() fail;
    # fall back to a one-significant-digit step in that case.
    if (!is.na(step) && step == 0) {
      step <- signif((y_max - y_min) / 5, 1)
    }
    lower <- floor(y_min / step) * step
    upper <- ceiling(y_max / step) * step
    list(ylim = c(lower, upper), yticks = seq(lower, upper, step))
  }

  # Driver/country combinations that are drawn as an "Insufficient data"
  # placeholder instead of a real panel
  ethnic_drivers <- c("Ethnic fractionalization", "Ethnic polarization", "Non-white share")
  media_drivers  <- c("Priv. 24-hr TV news (share)", "Priv. 24-hr TV news (count)",
                      "Share getting news online", "Share getting news on social media")

  # Axis of the most recently drawn real panel. Placeholder panels have no
  # axis of their own and reuse it.
  panel_axis <- NULL

  for (i in seq_len(nrow(country_map))) {
    country_i <- as.character(country_map$country[i])

    # Determine asterisks
    ast <- any(drop[, 2] == driver & drop[, 1] == country_i)

    is_missing <- (driver %in% ethnic_drivers && country_i %in% c("France")) ||
                  (driver == "Elite polarization" && country_i %in% c("Switzerland", "Japan")) ||
                  (driver %in% media_drivers && country_i == "New Zealand")

    # Plot
    if (!is_missing) {
      panel      <- data_all %>% filter(as.character(country) == !!country_i)
      panel_axis <- compute_axis(panel$y_min[1], panel$diff_max[1])
    } else if (is.null(panel_axis)) {
      # No real panel has been drawn yet, so there is no axis to reuse:
      # start the placeholder at the lowest value of any country.
      panel_axis <- compute_axis(min(data_all$y_min, na.rm = TRUE), data_all$diff_max[1])
    }

    make_connected_plot_i(i, driver, drivers_data, country_map, year_range,
                          panel_axis$ylim, panel_axis$yticks,
                          ast = ast, missing = is_missing)
  }

  return(NULL)
}

# Draw and save the connected plot of one driver for one country.
#
# i:            row of `country_map` identifying the country.
# driver:       name of the element of `drivers_data` to plot.
# drivers_data: named list of data frames with columns `country`, `years` and
#               the driver's value in column 3.
# country_map:  data frame with columns `country`, `name` and `extension`.
# year_range:   x-axis limits.
# ylim, yticks: y-axis limits and tick positions.
# ast:          if TRUE, an asterisk is appended to the country name in the
#               title.
# missing:      if TRUE, an "Insufficient data" note is drawn instead of the
#               points, the fitted line and the slope label.
#
# The plot is written to
# <build_path>/<country extension>_<name of column 3>_connected_plot.pdf,
# where `build_path` is the script-level variable. Returns NULL.
make_connected_plot_i <- function(i, driver, drivers_data, country_map, year_range = c(1965, 2020), ylim, yticks, ast = FALSE, missing = FALSE) {

  data              <- drivers_data[[driver]] %>% filter(country == as.character(country_map[i, "country"]))
  driver_extension  <- names(data)[3]
  names(data)[3]    <- "value"
  # as.character() keeps the labels intact when the country_map columns are
  # factors (a scalar ifelse() would return the integer code instead)
  name              <- as.character(country_map[i, "name"])
  country_extension <- as.character(country_map[i, "extension"])

  # Slope of value on years with a 1.96-standard-error interval. The standard
  # error is column 4 of dfadjustSE()'s coefficient table (presumably the
  # degrees-of-freedom-adjusted one -- confirm against the dfadjust docs).
  lm_eqn <- function(df) {
    fit      <- dfadjustSE(lm(value ~ years, df))
    estimate <- fit$coefficients[2, 1]
    std_err  <- fit$coefficients[2, 4]
    return(sprintf("Slope: %.02f\n(%.02f, %.02f)", estimate,
                   estimate - 1.96 * std_err,
                   estimate + 1.96 * std_err))
  }

  is_polarization <- driver == "Affective polarization"
  plot_title      <- if (isTRUE(as.logical(ast))) sprintf("%s*", name) else name
  # Only the US panel carries the driver name as its y-axis title
  y_title         <- if (name == "US") driver else ""

  connected_plot    <- ggplot(data, aes(x = years, y = value)) +
                       xlim(year_range) +
                       scale_y_continuous(limits = ylim, breaks = yticks) +
                       labs(title = plot_title, x = "", y = y_title)

  if (!missing) {
    connected_plot  <- connected_plot +
                       geom_point(color = if (is_polarization) "black" else "grey60") +
                       geom_smooth(color = if (is_polarization) "red4" else "black",
                                   method = "lm", formula = y~x, se = FALSE, show.legend = FALSE) +
                       annotate("text", x = 1966, y = (ylim[2] - ylim[1]) * .9 + ylim[1],
                                label = lm_eqn(data), colour = "red4", hjust = 0, size = 7.5)
  } else {
    connected_plot  <- connected_plot +
                       annotate("text", y = mean(ylim), x = mean(year_range), label = "Insufficient data")
  }

  connected_plot    <- connected_plot +
                       theme_bw() +
                       theme(panel.border     = element_blank(),
                             panel.grid.major = element_blank(),
                             panel.grid.minor = element_blank(),
                             axis.line        = element_line(colour = "black"),
                             legend.position  = "none",
                             plot.title       = element_text(hjust = 0.5, size = 40),
                             axis.title.x     = element_text(size = 30),
                             axis.title.y     = element_text(size = 18),
                             axis.text.x      = element_text(size = 18),
                             axis.text.y      = element_text(size = 18))

  # Pass the plot explicitly instead of relying on ggplot2's "last plot"
  ggsave(sprintf("%s/%s_%s_connected_plot.pdf", build_path, country_extension, driver_extension),
         plot = connected_plot, height = 5.1, width = 5.5)

  return(NULL)
}


# Relate each country's trend in one driver to its affective-polarization
# trend.
#
# driver:        name of the element of `drivers_data` to analyse.
# drivers_data:  named list of data frames with columns `country`, `years` and
#                the driver's value in column 3 (whose name is used as the
#                file-name extension).
# country_split: list returned by split_countries_polar().
# drop:          two-column matrix, column 1 = country, column 2 = driver;
#                countries listed for `driver` are excluded.
#
# Side effects:
#   * writes build/descriptive/slopes_<extension>.csv (country, driver slope)
#   * draws the scatter plot via make_plot()
#   * for "Internet penetration" additionally writes
#     build/descriptive/internet_post_2000_tests.txt, which repeats the
#     comparison using the post-2000 spline segments
# Returns NULL.
make_unity_plot <- function(driver, drivers_data, country_split, drop) {

  reg_data           <- drivers_data[[driver]]
  extension          <- names(reg_data)[3]
  names(reg_data)[3] <- "value"

  # Exclude the countries listed for this driver
  for (d in seq_len(nrow(drop))) {
    if (driver == drop[d, 2]) {
      excluded              <- drop[d, 1]
      reg_data              <- reg_data %>% dplyr::filter(!(country == !!excluded))
      country_split$rising  <- country_split$rising[!(country_split$rising == excluded)]
      country_split$falling <- country_split$falling[!(country_split$falling == excluded)]
    }
  }

  # Fit `model_formula` separately for each country and return the estimate in
  # the last row of the coefficient table, together with the countries that
  # could be fitted. Countries with fewer than two observations are left out
  # of BOTH vectors so that slopes and country labels stay aligned.
  # NOTE(review): summary() omits aliased coefficients, so if the last term
  # cannot be estimated the last row belongs to a different term -- confirm
  # that every fitted country identifies its slope.
  fit_slopes <- function(countries, model_formula) {
    usable <- logical(length(countries))
    slopes <- rep(NA_real_, length(countries))
    for (k in seq_along(countries)) {
      cty          <- as.character(countries[k])
      country_data <- reg_data %>% filter(country == !!cty)
      if (nrow(country_data) > 1) {
        fit       <- summary(lm(model_formula, data = country_data))
        slopes[k] <- fit$coefficients[nrow(fit$coefficients), 1]
        usable[k] <- TRUE
      }
    }
    list(countries = countries[usable], slopes = slopes[usable])
  }

  # Full-period driver slopes
  rising_fit  <- fit_slopes(country_split[["rising"]],  value ~ years)
  falling_fit <- fit_slopes(country_split[["falling"]], value ~ years)
  rising      <- rising_fit$slopes
  falling     <- falling_fit$slopes
  country_split[["rising"]]  <- rising_fit$countries
  country_split[["falling"]] <- falling_fit$countries

  # Polarization slopes of the same countries, in the same order
  slopes <- c(country_split[["slopes"]][match(country_split[["rising"]],  country_split[["countries"]])],
              country_split[["slopes"]][match(country_split[["falling"]], country_split[["countries"]])])

  # Save driver slopes for correlation matrix
  slopes_out <- cbind(c(country_split[["rising"]], country_split[["falling"]]),
                      c(rising, falling))
  colnames(slopes_out) <- c("country", driver)
  write.csv(slopes_out, sprintf("build/descriptive/slopes_%s.csv", extension))

  # make_plot() runs the rank-correlation test itself
  make_plot(country_split, slopes, rising, falling, driver, extension)

  # Post-2000 test for Internet Penetration
  if (driver == "Internet penetration") {
    temp_split <- country_split
    post       <- temp_split$post_slopes
    # which() ignores countries whose post-2000 polarization slope is NA
    temp_split$rising  <- temp_split$countries[which(post > 0)]
    temp_split$falling <- temp_split$countries[which(post < 0)]

    rising_fit  <- fit_slopes(temp_split[["rising"]],  value ~ lspline(years, c(2000)))
    falling_fit <- fit_slopes(temp_split[["falling"]], value ~ lspline(years, c(2000)))
    rising      <- rising_fit$slopes
    falling     <- falling_fit$slopes
    temp_split[["rising"]]  <- rising_fit$countries
    temp_split[["falling"]] <- falling_fit$countries

    slopes <- c(temp_split[["post_slopes"]][match(temp_split[["rising"]],  temp_split[["countries"]])],
                temp_split[["post_slopes"]][match(temp_split[["falling"]], temp_split[["countries"]])])
    out <- spearman.test(slopes, c(rising, falling), alternative = "two.sided",
                         approximation = "exact")

    temp_numb_pos   <- sum(rising > 0) + sum(falling > 0)
    temp_numb_count <- length(c(temp_split$rising, temp_split$falling))
    tab <- rbind(c("Number increasing",  temp_numb_pos),
                 c("Number countries",   temp_numb_count),
                 c("Spearman's rank corr", out$estimate),
                 c("Spearman's p-value", out$p.value))

    write_gslab_table(tab, "build/descriptive/internet_post_2000_tests.txt", "<tab:internet_post_2000_tests>")
  }

  return(NULL)
}


# Scatter plot of country-level polarization trends against driver trends.
#
# country_split:   list whose `rising` and `falling` elements name the
#                  countries, in the same order as the slope vectors.
# slopes:          polarization slope per country (rising countries first).
# rising, falling: driver slope of the rising / falling countries.
# driver:          driver name, used for the x-axis title.
# extension:       file-name stem; the plot is saved as
#                  build/descriptive/unity_<extension>.pdf.
#
# Countries are labelled with their ISO3 code (the United States in a larger
# font), a linear fit is drawn, and Spearman's rank correlation with its exact
# two-sided p-value is printed in the upper left of the panel.
make_plot <- function(country_split, slopes, rising, falling, driver, extension){
  library(countrycode)

  explanatory <- c(rising, falling)

  # Build the columns directly so the numbers are not converted to text and back
  plot_data <- data.frame(
    country          = as.character(c(country_split[["rising"]], country_split[["falling"]])),
    polarization     = as.numeric(slopes),
    explanatory      = as.numeric(explanatory),
    stringsAsFactors = FALSE
  )
  plot_data$abbrev <- countrycode(plot_data$country, origin = "country.name", destination = "iso3c")
  plot_data$US     <- as.factor(as.numeric(plot_data$country == "United States"))

  p <- ggplot(plot_data, aes(explanatory, polarization)) +  # https://stackoverflow.com/questions/43417514/getting-rid-of-border-in-pdf-output-for-geom-label-for-ggplot2-in-r
    geom_hline(yintercept = 0, linetype = 'dotted', colour = "grey65", size = .8) +
    geom_smooth(method = "lm", se = FALSE, colour = "black")

  # Mark zero on the x axis only when some driver slope is negative
  if (min(explanatory) < 0) {
    p <- p + geom_vline(xintercept = 0, linetype = 'dotted', colour = "grey65", size = .8)
  }

  out <- spearman.test(slopes, explanatory, alternative = "two.sided",
                       approximation = "exact")

  # Label position: 15% in from the left, in the head-room above the highest
  # polarization slope
  y_min <- min(slopes)
  y_max <- max(slopes) + (max(slopes) - min(slopes)) * .45

  x_spot  <- .15 * (max(explanatory) - min(explanatory)) + min(explanatory)
  y_spot  <- .95 * (y_max - y_min) + y_min
  y_spot2 <- .88 * (y_max - y_min) + y_min

  # Add pvalue labels
  cor_lab <- sprintf("rank cor: %.3f",     out$estimate)
  p_lab   <- sprintf("rank p-value: %.3f", out$p.value)
  p <- p  + annotate("text", x = x_spot, y = y_spot,  label = cor_lab, colour = "red4") +
            annotate("text", x = x_spot, y = y_spot2, label = p_lab,   colour = "red4")

  p <- p + ylab("Affective polarization trend") + xlab(sprintf("%s trend", driver)) +
    geom_label(aes(label = abbrev, size = US), label.size = NA, fill = alpha(c("white"),0), show.legend = FALSE) +
    scale_size_manual(values = c(3, 4.3)) # https://stackoverflow.com/questions/48557889/ggrepel-label-with-transparent-background-but-visible-font
  p <- p + theme_bw() +
    theme(panel.border     = element_blank(),
          panel.grid.major = element_blank(),
          panel.grid.minor = element_blank(),
          axis.line        = element_line(colour = "black"),
          legend.position  = c(.9, .92), # https://stackoverflow.com/questions/47584766/draw-a-box-around-a-legend-ggplot2
          legend.title = element_blank(),
          legend.spacing.y = unit(0, "mm"),
          plot.title       = element_text(hjust = 0.5, size = 16),
          axis.title.x     = element_text(size = 14),
          axis.title.y     = element_text(size = 14),
          axis.text.x      = element_text(size = 12),
          axis.text.y      = element_text(size = 12))

  # Pass the plot explicitly instead of relying on ggplot2's "last plot"
  ggsave(sprintf("build/descriptive/unity_%s.pdf", extension), plot = p, height = 3.8, width = 5.5)
}


# Lower-case the first character of each string and leave the rest untouched,
# e.g. "United" -> "united". Vectorised over `str`.
str_decap <- function(str) {

  first_char <- str_to_lower(substring(str, 1, 1))
  remainder  <- substring(str, 2)

  return(paste0(first_char, remainder))
}

### Execute
# Run the whole pipeline as soon as this file is sourced or executed.
main()
